Skip to content

Conversation

@janimo
Copy link
Contributor

@janimo janimo commented Nov 5, 2025

Make get() use squeeze() instead of reshape(), just like get_on_dim() does. It should avoid a copy when the tensor is not contiguous.
Also use the same name for the index argument.

@ivarflakstad
Copy link
Member

Thanks!
This is interesting, because I can't at a glance tell if .reshape(&dims[1..]) is always identical to .squeeze(0) (after applying .narrow(0, index, 1) of course) . It makes sense that it is, and from looking at the code it does seem like it is. I also modified a unit test to check quickly and the results were equal, and they were, but I don't know if I've covered all edge cases.
Also important to note that even if it isn't always identical that doesn't necessarily mean that this is wrong, it could still be an improvement.

If you could provide some kind of additional proof that would be great.
For example add a test in tensor_tests.rs that uses get on a bunch of different tensors with expected results, and ensure the results are the same on main and on this branch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants